
import torch
from random import shuffle
import os
import numpy as np
import sys
from PIL import Image
import codecs
from typing import Tuple, Any
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from copy import deepcopy
from torchvision import transforms, datasets
from torchvision.datasets import CIFAR100, CIFAR10, SVHN, MNIST
from torchvision.datasets.folder import ImageFolder, default_loader
from torchvision.datasets.utils import download_url, check_integrity, verify_str_arg, download_and_extract_archive
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


from torchvision.transforms import functional as Fv
# Bicubic interpolation flag for transforms.Resize: the InterpolationMode enum
# on torchvision versions that provide it, otherwise the legacy integer code
# 3 (PIL's BICUBIC) accepted by older versions.
try:
    interpolation = Fv.InterpolationMode.BICUBIC
except AttributeError:
    interpolation = 3
from PIL import Image

from src.utils_data import Continual_Dataset,logger

# Total number of classes of every single-dataset benchmark accepted by
# Continual_IC_Dataset; used there to derive the number of tasks.
dataset2numclass = {
    'CIFAR10': 10,
    'CIFAR100': 100,
    'tiny-imagenet-200': 200,
    'ImageNetA': 200,
    'ImageNetR': 200,
    'objectnet': 200,
    'omnibenchmark': 300,
    'vtab': 50,
}

def IC_Generate_Batch(data, transform):
    """Collate a list of ``(index, sample, label)`` triples into one batch.

    A sample is either a file path (loaded with ``pil_loader`` first) or an
    object passed to ``transform`` as-is; the type of the first sample decides
    for the whole batch.

    Returns:
        (tuple of indices, tensor of stacked transformed samples, LongTensor of labels)
    """
    indices, samples, labels = zip(*data)
    label_tensor = torch.LongTensor(labels)
    from_paths = isinstance(samples[0], str)
    transformed = []
    for sample in samples:
        if from_paths:
            sample = pil_loader(sample)
        transformed.append(transform(sample))
    return indices, torch.stack(transformed, dim=0), label_tensor

class Continual_IC_Dataset(Continual_Dataset):
    '''
        Continual Learning with an Image Classification dataset.

        Splits a dataset into a sequence of tasks and keeps one train/dev/test
        DataLoader per task in ``self.data_loader[phase][task_id]``, together
        with the per-task class bookkeeping (NUM_TASK, CUR_NUM_CLASS, CUR_CLASS,
        ACCUM_NUM_CLASS, PRE_ACCUM_NUM_CLASS, LABEL_LIST).
    '''
    def __init__(self, dataset, batch_size, class_ft, class_pt, backbone, is_mix_er, seed=None, num_workers=1):
        '''
            Args:
                dataset: 'five_datasets' (SVHN, MNIST, CIFAR10, NotMNIST and
                    FashionMNIST, one task each) or a key of ``dataset2numclass``
                    (one dataset split class-incrementally).
                batch_size: batch size of every loader; train loaders of tasks
                    after the first use ``batch_size // 2`` when ``is_mix_er``.
                class_ft: number of classes of the first task (single dataset).
                class_pt: number of classes of every later task (single dataset).
                backbone: backbone name, selects the transforms (``build_transform``).
                is_mix_er: halve the train batch size for tasks > 0 (presumably to
                    leave room for replayed samples -- confirm with the trainer).
                seed: if not None, class ids of the folder-based datasets are
                    permuted in ``read_ic_data``.
                num_workers: DataLoader workers. NOTE(review): only applied to the
                    'five_datasets' loaders.
        '''

        self.batch_size = batch_size
        self.data_loader = {'train': [], 'test': [], 'dev': []}
        self.backbone = backbone

        self.seed = seed

        assert isinstance(dataset, str), 'Only support one dataset'
        self.DATASET_LIST = [dataset]

        # Load data
        logger.info("Load train, dev and test set")
        if dataset == 'five_datasets':

            # Constant for Continual IC: one task per dataset
            self.NUM_TASK = 5
            self.CUR_NUM_CLASS = []
            self.CUR_CLASS = []

            num_label_cnt = 0
            for t_id, one_dataset in enumerate(['SVHN', 'MNIST', 'CIFAR10', 'NotMNIST', 'FashionMNIST']):

                one_train_dataset, one_dev_dataset, one_test_dataset = self.read_ic_data(one_dataset)

                if torch.is_tensor(one_train_dataset.targets):
                    num_labels = len(set(one_train_dataset.targets.numpy()))
                else:
                    num_labels = len(set(one_train_dataset.targets))
                self.CUR_NUM_CLASS.append(num_labels)
                self.CUR_CLASS.append(list(range(num_label_cnt, num_label_cnt + num_labels)))

                # Map the target in the dataset to the global index (offset by the
                # number of classes of the datasets seen before this one)
                for one_split in (one_train_dataset, one_dev_dataset, one_test_dataset):
                    one_split.targets = np.array([num_label_cnt + _targets for _targets in one_split.targets])

                self.data_loader['train'].append(DataLoader(
                                            dataset=one_train_dataset,
                                            batch_size=self.batch_size // 2 if t_id > 0 and is_mix_er else self.batch_size,
                                            shuffle=True,
                                            num_workers=num_workers))
                self.data_loader['dev'].append(DataLoader(
                                            dataset=one_dev_dataset,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            num_workers=num_workers))
                self.data_loader['test'].append(DataLoader(
                                            dataset=one_test_dataset,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            num_workers=num_workers))

                num_label_cnt += num_labels

            self.LABEL_LIST = list(range(num_label_cnt))
            self.ACCUM_NUM_CLASS = np.cumsum(self.CUR_NUM_CLASS)
            self.PRE_ACCUM_NUM_CLASS = [0]
            self.PRE_ACCUM_NUM_CLASS.extend(self.ACCUM_NUM_CLASS[:-1])

            # Data statistic
            logger.info("Train size %s; Dev size %s; Test size: %s;" % (
                [len(self.data_loader['train'][i].dataset.targets) for i in range(self.NUM_TASK)],
                [len(self.data_loader['dev'][i].dataset.targets) for i in range(self.NUM_TASK)],
                [len(self.data_loader['test'][i].dataset.targets) for i in range(self.NUM_TASK)]))
            logger.info('Label_list = %s' % str(self.LABEL_LIST))

        else:

            assert dataset in dataset2numclass, 'Not implemented for %s' % (dataset)
            assert (dataset2numclass[dataset] - class_ft) % class_pt == 0, 'Invalid setting of class_ft %d and class_pt %d' % (class_ft, class_pt)

            # Constant for Continual IC: class_ft classes in the first task and
            # class_pt classes in each later task
            self.NUM_TASK = 1 + (dataset2numclass[dataset] - class_ft) // class_pt
            self.CUR_NUM_CLASS = [class_ft] + [class_pt] * (self.NUM_TASK - 1)
            self.LABEL_LIST = list(range(dataset2numclass[dataset]))
            self.ACCUM_NUM_CLASS = np.cumsum(self.CUR_NUM_CLASS)
            self.PRE_ACCUM_NUM_CLASS = [0]
            self.PRE_ACCUM_NUM_CLASS.extend(self.ACCUM_NUM_CLASS[:-1])
            self.CUR_CLASS = [list(range(self.PRE_ACCUM_NUM_CLASS[task_id], self.ACCUM_NUM_CLASS[task_id]))
                              for task_id in range(self.NUM_TASK)]

            train_dataset, dev_dataset, test_dataset = self.read_ic_data(dataset)

            for t_id in range(self.NUM_TASK):
                task_classes = self.CUR_CLASS[t_id]

                self.data_loader['train'].append(DataLoader(
                                            dataset=Subset(train_dataset, self._indices_of_classes(train_dataset, task_classes)),
                                            batch_size=self.batch_size // 2 if t_id > 0 and is_mix_er else self.batch_size,
                                            shuffle=True))

                self.data_loader['dev'].append(DataLoader(
                                            dataset=Subset(dev_dataset, self._indices_of_classes(dev_dataset, task_classes)),
                                            batch_size=self.batch_size,
                                            shuffle=False))

                self.data_loader['test'].append(DataLoader(
                                            dataset=Subset(test_dataset, self._indices_of_classes(test_dataset, task_classes)),
                                            batch_size=self.batch_size,
                                            shuffle=False))

            # Data statistic
            logger.info("Train size %s; Dev size %s; Test size: %s;" % (
                [len(self.data_loader['train'][i].dataset.indices) for i in range(self.NUM_TASK)],
                [len(self.data_loader['dev'][i].dataset.indices) for i in range(self.NUM_TASK)],
                [len(self.data_loader['test'][i].dataset.indices) for i in range(self.NUM_TASK)]))
            logger.info('Label_list = %s' % str(self.LABEL_LIST))

    @staticmethod
    def _indices_of_classes(dataset, classes):
        '''Return the positions k of ``dataset.targets`` with ``int(target)`` in ``classes``.'''
        # A set makes each membership test O(1) instead of scanning the class list.
        wanted = set(classes)
        return [k for k, target in enumerate(dataset.targets) if int(target) in wanted]

    def get_accum_data_loader(self, task_id, phase):
        '''
            Return one DataLoader over the data of tasks 0..task_id of ``phase``.

            Train loaders shuffle, dev/test loaders do not.

            Raises:
                ValueError: if ``phase`` is not 'train', 'dev' or 'test', or
                    ``task_id`` is negative.
        '''
        if phase not in ('train', 'dev', 'test'):
            raise ValueError('Unknown phase %s' % phase)
        if task_id < 0:
            raise ValueError('task_id must be >= 0, got %d' % task_id)

        task_datasets = [self.data_loader[phase][t_id].dataset for t_id in range(task_id + 1)]

        if all(isinstance(task_dataset, Subset) for task_dataset in task_datasets):
            # Single-dataset split: every task is a Subset of the same underlying
            # dataset, so the index lists can simply be concatenated.
            accum_idx_list = []
            for task_dataset in task_datasets:
                accum_idx_list.extend(task_dataset.indices)
            accum_dataset = Subset(task_datasets[-1].dataset, accum_idx_list)
        else:
            # 'five_datasets': each task is a different dataset object.
            from torch.utils.data import ConcatDataset
            accum_dataset = ConcatDataset(task_datasets)

        return DataLoader(dataset=accum_dataset,
                          batch_size=self.batch_size,
                          shuffle=(phase == 'train'))

    def read_ic_data(self, dataset):
        '''
            Build the (train, dev, test) datasets of ``dataset``.

            The dev and test datasets are both built from the test split.

            Raises:
                NotImplementedError: for an unsupported dataset name.
        '''

        root_datapath = './datasets'
        datapath = os.path.join(root_datapath, dataset)

        transform_train = build_transform(is_train=True, backbone=self.backbone)
        transform_test = build_transform(is_train=False, backbone=self.backbone)

        if dataset in ['ImageNetA', 'ImageNetR', 'omnibenchmark', 'vtab', 'objectnet']:
            # Folder-based datasets stored one level up as <name>/train and
            # <name>/test, each in ImageFolder layout.
            datapath = os.path.join('../datasets/', dataset)

            train_datapath = os.path.join(datapath, 'train')
            test_datapath = os.path.join(datapath, 'test')
            dataset_train = ImageDataset(train_datapath, train=True, transform=transform_train)
            dataset_val = ImageDataset(test_datapath, train=False, transform=transform_test)
            dataset_test = ImageDataset(test_datapath, train=False, transform=transform_test)

            # Shuffle the class ids so the class order of the tasks changes.
            # NOTE(review): uses the global ``random`` state; self.seed only
            # enables the permutation, the RNG is presumably seeded by the caller.
            if self.seed is not None:
                num_class = len(np.unique(dataset_train.targets))
                class_map = list(range(num_class))
                shuffle(class_map)
                dataset_train.targets = np.array([class_map[_tgt] for _tgt in dataset_train.targets])
                dataset_val.targets = np.array([class_map[_tgt] for _tgt in dataset_val.targets])
                dataset_test.targets = np.array([class_map[_tgt] for _tgt in dataset_test.targets])

        elif dataset == 'CIFAR100':
            # Following DyTox (stored under ./datasets/CIFAR100, unlike the others)
            dataset_train = MyCIFAR100(datapath, train=True, download=True, transform=transform_train)
            dataset_val = MyCIFAR100(datapath, train=False, download=True, transform=transform_test)
            dataset_test = MyCIFAR100(datapath, train=False, download=True, transform=transform_test)

        elif dataset == 'CIFAR10':
            dataset_train = MyCIFAR10(root_datapath, train=True, download=True, transform=transform_train)
            dataset_val = MyCIFAR10(root_datapath, train=False, download=True, transform=transform_test)
            dataset_test = MyCIFAR10(root_datapath, train=False, download=True, transform=transform_test)

        elif dataset == 'MNIST':
            dataset_train = MNIST_RGB(root_datapath, train=True, download=True, transform=transform_train)
            dataset_val = MNIST_RGB(root_datapath, train=False, download=True, transform=transform_test)
            dataset_test = MNIST_RGB(root_datapath, train=False, download=True, transform=transform_test)

        elif dataset == 'FashionMNIST':
            dataset_train = FashionMNIST(root_datapath, train=True, download=True, transform=transform_train)
            dataset_val = FashionMNIST(root_datapath, train=False, download=True, transform=transform_test)
            dataset_test = FashionMNIST(root_datapath, train=False, download=True, transform=transform_test)

        elif dataset == 'SVHN':
            dataset_train = SVHN(root_datapath, split='train', download=True, transform=transform_train)
            dataset_val = SVHN(root_datapath, split='test', download=True, transform=transform_test)
            dataset_test = SVHN(root_datapath, split='test', download=True, transform=transform_test)

        elif dataset == 'NotMNIST':
            dataset_train = NotMNIST(root_datapath, train=True, download=True, transform=transform_train)
            dataset_val = NotMNIST(root_datapath, train=False, download=True, transform=transform_test)
            dataset_test = NotMNIST(root_datapath, train=False, download=True, transform=transform_test)

        else:
            raise NotImplementedError()

        return dataset_train, dataset_val, dataset_test

def build_transform(is_train, backbone):
    """Return the torchvision transform pipeline for ``backbone``.

    Args:
        is_train: True for the (augmenting) training pipeline, False for evaluation.
        backbone: backbone name. Names containing 'vit' or 'deit' get 224x224
            inputs scaled to [0, 1] without normalization; 'resnet18' gets a
            CIFAR-style pipeline (RandomCrop(32, padding=4) when training) with
            per-channel normalization.

    Raises:
        NotImplementedError: for any other backbone.
    """
    if 'vit' in backbone or 'deit' in backbone:

        input_size = 224
        resize_im = input_size > 32
        if is_train:
            scale = (0.05, 1.0)
            ratio = (3. / 4., 4. / 3.)
            return transforms.Compose([
                transforms.RandomResizedCrop(input_size, scale=scale, ratio=ratio),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
            ])

        t = []
        if resize_im:
            size = int((256 / 224) * input_size)
            # Bicubic resize keeping the 256/224 ratio, then centre-crop.
            # ``interpolation`` is the module-level enum (legacy int on old torchvision)
            # instead of the deprecated bare integer 3.
            t.append(transforms.Resize(size, interpolation=interpolation))
            t.append(transforms.CenterCrop(input_size))
        t.append(transforms.ToTensor())

        return transforms.Compose(t)

    elif backbone == 'resnet18':
        # Per-channel mean/std (these look like CIFAR-100 statistics).
        normalize = transforms.Normalize((0.5071, 0.4867, 0.4408),
                                         (0.2675, 0.2565, 0.2761))
        if is_train:
            return transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        return transforms.Compose([transforms.ToTensor(), normalize])

    else:
        raise NotImplementedError()


def split_images_labels(imgs):
    """Split ``ImageFolder.imgs``-style ``(path, label)`` pairs into two numpy arrays.

    Returns:
        (paths, labels) as numpy arrays, in the input order.
    """
    pairs = list(imgs)
    paths = [pair[0] for pair in pairs]
    targets = [pair[1] for pair in pairs]
    return np.array(paths), np.array(targets)

class ImageDataset(Dataset):
    """Dataset over an ``ImageFolder``-layout directory returning ``(index, image, target)``.

    Image paths and integer targets are collected once at construction (via
    ``datasets.ImageFolder``); images are loaded as RGB in ``__getitem__``.
    """
    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform

        dset = datasets.ImageFolder(root)

        # data: array of image paths, targets: array of class indices
        self.data, self.targets = split_images_labels(dset.imgs)

    def __getitem__(self, index: int, is_transform: bool = True) -> Tuple[int, Any, Any]:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :param is_transform: if False (or no transform was given) the raw PIL image is returned
        :returns: tuple: (index, image, target) where target is index of the target class.
        """
        img, target = pil_loader(self.data[index]), self.targets[index]

        if is_transform and self.transform is not None:
            img = self.transform(img)
        return index, img, target

    def __len__(self):
        return len(self.targets)

def pil_loader(path):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    The file is opened explicitly and closed once the converted image has been
    produced, which avoids Pillow's ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).

    Ref:
    https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image

class TinyImageNet(Dataset):
    """Tiny-ImageNet dataset returning ``(index, image, target)``.

    Files read under ``root``: ``train/<wnid>/**/*.JPEG`` (train split),
    ``val/images/*.JPEG`` with ``val/val_annotations.txt`` (val split),
    ``wnids.txt`` and ``words.txt`` (human-readable class names).
    """
    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")

        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        with open(wnids_file, 'r') as fo:
            self.set_nids = {entry.strip("\n") for entry in fo}

        # wnid -> first comma-separated name listed for it in words.txt
        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            for entry in fo:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]

    def _create_class_idx_dict_train(self):
        """Map the sorted train sub-directory names (wnids) to target indices."""
        with os.scandir(self.train_dir) as entries:
            classes = sorted(d.name for d in entries if d.is_dir())

        num_images = 0
        for _root, _dirs, files in os.walk(self.train_dir):
            num_images += sum(1 for f in files if f.endswith(".JPEG"))
        self.len_dataset = num_images

        self.tgt_idx_to_class = dict(enumerate(classes))
        self.class_to_tgt_idx = {cls: i for i, cls in enumerate(classes)}

    def _create_class_idx_dict_val(self):
        """Read val_annotations.txt and map its sorted wnids to target indices."""
        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            for data in fo:
                # columns: image file name, wnid, ...
                words = data.split("\t")
                self.val_img_to_class[words[0]] = words[1]
                set_of_classes.add(words[1])

        self.len_dataset = len(self.val_img_to_class)
        classes = sorted(set_of_classes)
        self.class_to_tgt_idx = {cls: i for i, cls in enumerate(classes)}
        self.tgt_idx_to_class = dict(enumerate(classes))

    def _make_dataset(self, Train=True):
        """Collect ``(path, target)`` pairs into ``self.images`` and targets into ``self.targets``."""
        self.images = []
        self.targets = []
        if Train:
            img_root_dir = self.train_dir
            list_of_dirs = list(self.class_to_tgt_idx.keys())
        else:
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]

        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if fname.endswith(".JPEG"):
                        path = os.path.join(root, fname)
                        if Train:
                            item = (path, self.class_to_tgt_idx[tgt])
                        else:
                            item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                        self.images.append(item)
                        self.targets.append(item[1])

    def return_label(self, idx):
        """Return the human-readable names of the targets in ``idx`` (items must support ``.item()``)."""
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        # Count what was actually indexed; ``len_dataset`` is a pre-count that
        # can disagree (e.g. annotated val images missing on disk).
        return len(self.images)

    def __getitem__(self, idx, is_transform: bool = True):
        img_path, tgt = self.images[idx]
        with open(img_path, 'rb') as f:
            # convert() inside the with-block forces the pixel data to be read
            # before the file handle is closed
            sample = Image.open(f).convert('RGB')
        if is_transform and self.transform is not None:
            sample = self.transform(sample)
        return idx, sample, tgt

class MyCIFAR100(CIFAR100):
    """
    Overrides the CIFAR100 dataset to change the getitem function: it returns
    ``(index, image, target)`` and can skip the transform.
    """
    def __init__(self, root, train=True, transform=None, download=False) -> None:
        self.root = root
        super(MyCIFAR100, self).__init__(root, train=train, transform=transform, download=download)

    def __getitem__(self, index: int, is_transform: bool = True) -> Tuple[int, Any, Any]:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :param is_transform: if False (or no transform was given) the raw PIL image is returned
        :returns: tuple: (index, image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # to return a PIL Image
        img = Image.fromarray(img, mode='RGB')
        if is_transform and self.transform is not None:
            img = self.transform(img)
        return index, img, target
        
class MyCIFAR10(CIFAR10):
    """
    Overrides the CIFAR10 dataset to change the getitem function: it returns
    ``(index, image, target)`` and can skip the transform.
    """
    def __init__(self, root, train=True, transform=None, download=False) -> None:
        self.root = root
        super(MyCIFAR10, self).__init__(root, train=train, transform=transform, download=download)

    def __getitem__(self, index: int, is_transform: bool = True) -> Tuple[int, Any, Any]:
        """
        Gets the requested element from the dataset.
        :param index: index of the element to be returned
        :param is_transform: if False (or no transform was given) the raw PIL image is returned
        :returns: tuple: (index, image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # to return a PIL Image
        img = Image.fromarray(img, mode='RGB')
        if is_transform and self.transform is not None:
            img = self.transform(img)
        return index, img, target

class MNIST_RGB(MNIST):
    """MNIST whose ``__getitem__`` returns ``(index, image, target)`` with the
    grayscale digit converted to a 3-channel RGB PIL image."""

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        # Forward ``train`` so the base class works on the requested split
        # instead of its default one.
        super(MNIST_RGB, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
        self.train = train  # training set or test set

        # Load the requested split with this class's own loaders.
        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self):
        """True if both legacy processed files exist in ``processed_folder`` and pass ``check_integrity``."""
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        """Read images and labels of the current split from the raw IDX files."""
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int, is_transform: bool = True) -> Tuple[Any, Any, Any]:
        """
        Args:
            index (int): Index
            is_transform (bool): apply ``self.transform`` (if set) to the image

        Returns:
            tuple: (index, image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image. Best effort: if the stored sample is not a
        # 2-D tensor it is passed on unchanged.
        try:
            img = Image.fromarray(img.numpy(), mode='L').convert('RGB')
        except (AttributeError, TypeError, ValueError):
            pass

        if self.transform is not None and is_transform:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

class FashionMNIST(MNIST_RGB):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Images are returned as 3-channel RGB PIL images and ``__getitem__`` yields
    ``(index, image, target)`` (both inherited from ``MNIST_RGB``).

    Args:
        root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
            and  ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]

    # (file name, md5) of the raw archives
    resources = [
        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
    ]
    classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

    # __getitem__ is inherited from MNIST_RGB (the previous override was an
    # exact copy of it).

class NotMNIST(MNIST_RGB):
    """notMNIST letters A-J (labels 0-9) as RGB images; ``__getitem__`` returns
    ``(index, image, target)``.

    Downloads ``notMNIST.zip`` into ``root`` if it is missing, extracts it once
    (to ``root/notMNIST``) and loads every image of the ``Train`` or ``Test``
    folder into memory. ``MNIST_RGB.__init__`` is not called; all attributes
    used here are set explicitly.
    """
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        self.url = 'https://github.com/facebookresearch/Adversarial-Continual-Learning/raw/main/data/notMNIST.zip'
        self.filename = 'notMNIST.zip'

        fpath = os.path.join(self.root, self.filename)
        if not os.path.isfile(fpath):
            if not download:
                raise RuntimeError('Dataset not found. You can use download=True to download it')
            print('Downloading from ' + self.url)
            download_url(self.url, self.root, filename=self.filename)

        import zipfile
        extracted_dir = os.path.join(self.root, 'notMNIST')
        # Extract only once instead of on every construction (train/dev/test).
        if not os.path.isdir(extracted_dir):
            with zipfile.ZipFile(fpath, 'r') as zip_ref:
                zip_ref.extractall(self.root)

        split_dir = os.path.join(extracted_dir, 'Train' if self.train else 'Test')

        X, Y = [], []
        # Sorted for a reproducible sample order.
        for folder in sorted(os.listdir(split_dir)):
            folder_path = os.path.join(split_dir, folder)
            if not os.path.isdir(folder_path):
                continue  # skip stray files next to the class folders
            label = ord(folder) - 65  # Folders are A-J so labels will be 0-9
            for ims in sorted(os.listdir(folder_path)):
                img_path = os.path.join(folder_path, ims)
                try:
                    with Image.open(img_path) as im:
                        pixels = np.array(im.convert('RGB'))
                except Exception:
                    # Best effort: the archive contains some unreadable files.
                    print("File {}/{} is broken".format(folder, ims))
                    continue
                # Append image and label together so X and Y stay aligned.
                X.append(pixels)
                Y.append(label)
        self.data = np.array(X)
        self.targets = Y

    def __getitem__(self, index: int, is_transform: bool = True) -> Tuple[Any, Any, Any]:
        """
        Args:
            index (int): Index
            is_transform (bool): apply ``self.transform`` (if set) to the image

        Returns:
            tuple: (index, image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image (best effort: non-array samples pass through)
        try:
            img = Image.fromarray(img)
        except (AttributeError, TypeError, ValueError):
            pass

        if self.transform is not None and is_transform:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

class SVHN(SVHN):
    """SVHN whose ``__getitem__`` returns ``(index, image, target)``.

    Labels are exposed as ``targets`` (label 10 mapped to 0) and ``classes``
    holds the sorted unique labels. NOTE: this class shadows, and subclasses,
    the torchvision ``SVHN`` imported at module top.
    """
    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        # torchvision's SVHN.__init__ already validates ``split``, sets
        # url/filename/file_md5, downloads, checks integrity, loads the .mat file
        # into ``data`` (transposed to N, C, H, W) and ``labels`` (10 -> 0).
        # Repeating that here loaded the .mat file a second time.
        # Zero-argument super() avoids depending on the shadowed global name.
        super().__init__(root, split=split, transform=transform, target_transform=target_transform, download=download)

        # int64 copy, independent of the base-class ``labels`` array.
        self.targets = np.array(self.labels, dtype=np.int64)
        self.classes = np.unique(self.targets)

    def __getitem__(self, index: int, is_transform: bool = True) -> Tuple[Any, Any, Any]:
        """
        Args:
            index (int): Index
            is_transform (bool): apply ``self.transform`` (if set) to the image

        Returns:
            tuple: (index, image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image (data is stored channel-first)
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None and is_transform:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        root = self.root
        md5 = self.split_list[self.split][2]
        fpath = os.path.join(root, self.filename)
        return check_integrity(fpath, md5)

    def download(self) -> None:
        md5 = self.split_list[self.split][2]
        download_url(self.url, self.root, self.filename, md5)

    def extra_repr(self) -> str:
        return "Split: {split}".format(**self.__dict__)

def get_int(b):
    """Interpret the bytes-like ``b`` as a big-endian unsigned integer."""
    hex_digits = b.hex()
    return int(hex_digits, 16)


def open_maybe_compressed_file(path):
    """Return a binary file object for ``path``, decompressing on the fly if needed.

    A non-string argument is assumed to already be a file object and is
    returned unchanged. String paths ending in '.gz' or '.xz' are opened with
    gzip / lzma; any other path is opened as a plain binary file.
    """
    if isinstance(path, str):
        if path.endswith('.gz'):
            import gzip
            opener = gzip.open
        elif path.endswith('.xz'):
            import lzma
            opener = lzma.open
        else:
            opener = open
        return opener(path, 'rb')
    return path


def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    Argument may be a filename, compressed filename (.gz/.xz), or file object.

    Header, as parsed below: a 4-byte big-endian magic number whose value
    modulo 256 is the number of dimensions ``nd`` and whose value divided by
    256 is the element type code; then ``nd`` 4-byte big-endian dimension
    sizes; then the big-endian element data.

    Args:
        path: filename or binary file object.
        strict: if True, require the element count to equal the product of the
            dimension sizes.

    Returns:
        torch.Tensor with the declared shape, in native byte order, backed by a
        writable copy of the data.

    Raises:
        ValueError: for an unsupported rank or type code, or (``strict``) an
            element count that does not match the declared shape.
    """
    # typemap: type code -> (torch dtype, on-disk big-endian numpy dtype, native numpy dtype)
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')}
    typemap = read_sn3_pascalvincent_tensor.typemap
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    if not 1 <= nd <= 3:
        raise ValueError('Unsupported number of dimensions in SN3 header: %d' % nd)
    if ty not in typemap:
        raise ValueError('Unsupported SN3 type code: %d' % ty)
    _, disk_dtype, native_dtype = typemap[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=disk_dtype, offset=(4 * (nd + 1)))
    if strict and parsed.shape[0] != np.prod(s):
        raise ValueError('SN3 payload has %d elements, header declares shape %s' % (parsed.shape[0], s))
    # np.frombuffer returns a read-only view of ``data``; copying yields a
    # writable array (torch.from_numpy warns on non-writable arrays).
    return torch.from_numpy(np.array(parsed, dtype=native_dtype)).view(*s)


def read_label_file(path):
    """Load an IDX label file as a 1-D ``torch.long`` tensor.

    Asserts (unless Python runs with -O) that the file holds a 1-D uint8 tensor.
    """
    with open(path, 'rb') as handle:
        labels = read_sn3_pascalvincent_tensor(handle, strict=False)
    assert labels.dtype == torch.uint8
    assert labels.ndimension() == 1
    return labels.long()


def read_image_file(path):
    """Load an IDX image file as a 3-D uint8 tensor (as read by ``read_sn3_pascalvincent_tensor``).

    Asserts (unless Python runs with -O) that the file holds a 3-D uint8 tensor.
    """
    with open(path, 'rb') as handle:
        images = read_sn3_pascalvincent_tensor(handle, strict=False)
    assert images.dtype == torch.uint8
    assert images.ndimension() == 3
    return images

if __name__ == "__main__":
    pass
